from utils import *
import sys
from math import inf
import time
from tabulate import tabulate
from sklearn.cluster import KMeans
from math import isnan

# Problem sizes: m is the per-sample dimension (used by torch.rand(m) and
# np.eye(m) below); K and M are forwarded unchanged to compute_params,
# FCP_DRO and the SSG_* solvers further down.
m, K, M = 10, 10, 1
# Shape parameters and value range handed to generate_mu.
a, b = 3.0, 3.0
# Previously tried range: low = 5.0, high = 10.0
low, high = 5.0, 8.0

# Draw the four mean vectors one after another (mu0 first, mu3 last) and keep
# a name -> vector record of them.
mu0, mu1, mu2, mu3 = [generate_mu(m, low, high, a, b) for _ in range(4)]
log = dict(mu0=mu0, mu1=mu1, mu2=mu2, mu3=mu3)

N_tot = 500  # sample count requested from generate_theta_normal
var = 3.0    # each covariance passed below is var * I_m

# Generate attacker data: generate_theta_normal receives (mean, covariance)
# pairs for mu0..mu3 -- each pair with its own freshly built var * I_m
# matrix -- followed by the sample count.
_normal_args = []
for _mu in (mu0, mu1, mu2, mu3):
    _normal_args += [_mu, var * np.eye(m)]
theta_1, theta_2 = generate_theta_normal(*_normal_args, N_tot)
# Join the two returned arrays along their third axis.
full_theta = np.concatenate([theta_1, theta_2], axis=2)

for _label, _mu in (("Mu0 : ", mu0), ("Mu1 : ", mu1), ("Mu2 : ", mu2), ("Mu3 : ", mu3)):
    print(_label, _mu)

print(full_theta.shape)

# Constant forwarded unchanged to every FCP_DRO / SSG_* / utility_robust call.
xi = 1_000_000.0
# Uniform weight vector with N_tot + 1 entries: one per sample plus a trailing
# slot (the clustered runs below also build N+1-length weight vectors whose
# last entry is 0).
w = np.full(N_tot + 1, 1.0)

# Disabled full-data FCP_DRO baseline, kept for reference:
# A, b, C, d, tL, tU = compute_params(m, K, N_tot, numerator, denominator, full_theta)
# w[-1] = 0
# tot_z_dro = FCP_DRO(m, K, N_tot, N_tot, M, tL, tU, A, b, C, d, w, xi)
# values_DRO = FCP_values(tot_z_dro, full_theta)
# opt_tot = np.mean(values_DRO)
# opt_tot = utility_robust(full_theta, numerator, denominator, m, tot_z_dro, N_tot, xi, w)

# Independent copy of the samples; it is reshaped in place of the name later,
# so full_theta itself stays in its original layout.
tot_theta = np.array(full_theta, copy=True)

# Baselines: two gradient-based solvers, each restarted from a fresh random
# initial point whenever its evaluated utility comes out NaN.

# Cap on random restarts, so a solver that diverges every time fails loudly
# instead of hanging the script.
MAX_RESTARTS = 20


def _restart_until_finite(trial, label, max_restarts=MAX_RESTARTS):
    """Run ``trial()`` until the utility it reports is not NaN.

    ``trial`` takes no arguments, draws its own random starting point and
    returns a ``(solution, utility)`` pair.

    Returns:
        The ``(solution, utility)`` pair of the first trial whose utility is
        not NaN.

    Raises:
        RuntimeError: if all ``max_restarts`` trials produce a NaN utility.
    """
    for _ in range(max_restarts):
        solution, utility = trial()
        # `isnan` is imported from math at the top of the file; `math` itself
        # is never imported there.
        if not isnan(utility):
            return solution, utility
    raise RuntimeError(
        f"{label}: utility was NaN in all {max_restarts} random restarts"
    )


def _tt_gd_trial():
    """One SSG_TTGD run from a random projected start and uniform sample weights."""
    p_init = M_simplex_projection(M, torch.rand(m))
    w_init = torch.ones(N_tot) / N_tot
    p_sol = SSG_TTGD(p_init, w_init, N_tot, numerator, denominator, tot_theta,
                     m, M, xi, num_epochs=500, lr_p=0.001, lr_w=0.001)
    value = utility_robust(full_theta, numerator, denominator, m, p_sol,
                           N_tot, xi, w, 1)
    return p_sol, value


def _gd_trial():
    """One SSG_gradient_descent run from a random projected start.

    The positional ``500, 0.001`` are presumably the epoch count and the
    learning rate (same values as the TTGD run) -- confirm against utils.
    """
    p_init = M_simplex_projection(M, torch.rand(m))
    p_sol = SSG_gradient_descent(p_init, N_tot, numerator, denominator,
                                 tot_theta, m, M, w, N_tot, xi, 500, 0.001)
    value = utility_robust(full_theta, numerator, denominator, m, p_sol,
                           N_tot, xi, w, 1)
    return p_sol, value


p, TT_tot_GD = _restart_until_finite(_tt_gd_trial, "SSG_TTGD")
print(TT_tot_GD)

p, opt_tot_GD = _restart_until_finite(_gd_trial, "SSG_gradient_descent")
print(TT_tot_GD)
print(opt_tot_GD)

# Clustered FCP_DRO: compress the N_tot samples into N k-means centroids and
# weight each centroid by the number of samples assigned to it.

# One flattened row per sample for k-means.
tot_theta = tot_theta.reshape(N_tot, -1)
N = 50
kmeans = KMeans(n_clusters=N).fit(tot_theta)
# Restore the per-sample (m, ...) layout for the centroids; -1 infers the
# trailing width from the data instead of hard-coding it.
cluster_cent = kmeans.cluster_centers_.reshape(N, m, -1)
Y = kmeans.predict(tot_theta)
# s[k] = number of samples assigned to cluster k. The array has N + 1 float
# entries, and the extra trailing slot stays 0.
s = np.bincount(Y, minlength=N + 1).astype(float)
# NOTE: this rebinds `b`, which earlier held the shape parameter passed to
# generate_mu.
A, b, C, d, tL, tU = compute_params(m, K, N, numerator, denominator, cluster_cent)
z_dro = FCP_DRO(m, K, N, N_tot, M, tL, tU, A, b, C, d, s, xi)
# NOTE(review): unlike the SSG_* evaluations, this call omits the final
# positional argument `1` -- confirm the default is intended for FCP_DRO output.
opt_tot_method = utility_robust(full_theta, numerator, denominator, m, z_dro, N_tot, xi, w)
print(TT_tot_GD)
print(opt_tot_GD)
print(opt_tot_method)

# Stratified re-sampling: cluster the samples into N // samp strata, draw
# `samp` points from each stratum, solve FCP_DRO on the drawn points, and keep
# the best evaluated utility over several repetitions.
samp = 5
num_clusters = N // samp
# Number of representative points actually passed to the solver; this equals
# N only when N is a multiple of samp.
n_sampled = num_clusters * samp
kmeans = KMeans(n_clusters=num_clusters).fit(tot_theta)
Y = kmeans.predict(tot_theta)
# s[k] = size of stratum k. There are num_clusters + 1 float entries, and the
# trailing slot is 0.
s = np.bincount(Y, minlength=num_clusters + 1).astype(float)
data_points = tot_theta.reshape(N_tot, m, -1)
strata = [[] for _ in range(num_clusters)]
for label, point in zip(Y, data_points):
    strata[label].append(point)

# Each drawn point stands for s[k] / samp samples of its stratum, so stratum
# k's weights sum to its size. A trailing 0 pads the vector to n_sampled + 1
# entries. The vector is the same for every repetition, so it is built once.
ls = np.append(np.repeat(s[:num_clusters] / samp, samp), 0.0)

# -inf rather than 0, so negative utilities are not masked. NaN utilities
# never pass `opt > best`; if every repetition yields NaN, best stays -inf.
best = -inf
for rep in range(10):
    sampled_points = []
    for stratum in strata:
        # replace=False raises ValueError for strata smaller than samp, so
        # fall back to sampling with replacement there. Every stratum then
        # still contributes exactly samp points.
        picks = np.random.choice(len(stratum), samp, replace=len(stratum) < samp)
        sampled_points.extend(stratum[j] for j in picks)
    # Stack into one ndarray, matching the centroid array used for the
    # clustered run.
    sampled_points = np.array(sampled_points)
    A, b, C, d, tL, tU = compute_params(m, K, n_sampled, numerator, denominator, sampled_points)
    z_dro = FCP_DRO(m, K, n_sampled, N_tot, M, tL, tU, A, b, C, d, ls, xi)
    opt = utility_robust(full_theta, numerator, denominator, m, z_dro, N_tot, xi, w)
    if opt > best:
        best = opt

print(TT_tot_GD)
print(opt_tot_GD)
print(opt_tot_method)
print(best)
